import copy
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import sympy
import torch
import torch.nn as nn
from tqdm import tqdm

from .KANLayer import KANLayer
from .LBFGS import LBFGS
from .spline import curve2coef
from .Symbolic_KANLayer import SYMBOLIC_LIB, Symbolic_KANLayer


class KAN(nn.Module):
    """
    KAN class

    Attributes:
    -----------
        biases: a list of nn.Linear()
            biases are added on nodes (in principle, biases can be absorbed
            into activation functions. However, we still have them for better optimization)
        act_fun: a list of KANLayer
            KANLayers
        depth: int
            depth of KAN
        width: list
            number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs,
            3D outputs, with 2 layers of 5 hidden neurons.
        grid: int
            the number of grid intervals
        k: int
            the order of piecewise polynomial
        base_fun: fun
            residual function b(x). an activation function
            phi(x) = sb_scale * b(x) + sp_scale * spline(x)
        symbolic_fun: a list of Symbolic_KANLayer
            Symbolic_KANLayers
        symbolic_enabled: bool
            If False, the symbolic front is not computed (to save time).
            Default: True.

    Methods:
    --------
        __init__():
            initialize a KAN
        initialize_from_another_model():
            initialize a KAN from another KAN (with the same shape, but
            potentially different grids)
        update_grid_from_samples():
            update spline grids based on samples
        initialize_grid_from_another_model():
            initialize KAN grids from another KAN
        forward():
            forward
        set_mode():
            set the mode of an activation function: 'n' for numeric, 's' for
            symbolic, 'ns' for combined (note they are visualized differently
            in plot(). 'n' as black, 's' as red, 'ns' as purple).
        fix_symbolic():
            fix an activation function to be symbolic
        suggest_symbolic():
            suggest the symbolic candidates of a numeric spline-based
            activation function
        lock():
            lock activation functions to share parameters
        unlock():
            unlock locked activations
        get_range():
            get the input and output ranges of an activation function
        plot():
            plot the diagram of KAN
        train_kan():
            train KAN
        prune():
            prune KAN
        remove_edge():
            remove some edge of KAN
        remove_node():
            remove some node of KAN
        auto_symbolic():
            automatically fit all splines to be symbolic functions
        symbolic_formula():
            obtain the symbolic formula of the KAN network
    """

    def __init__(
        self,
        width=None,
        grid=3,
        k=3,
        noise_scale=0.1,
        noise_scale_base=0.1,
        base_fun=torch.nn.SiLU(),
        symbolic_enabled=True,
        bias_trainable=True,
        grid_eps=1.0,
        grid_range=[-1, 1],
        sp_trainable=True,
        sb_trainable=True,
        seed=0,
    ):
        """
        initalize a KAN model

        Args:
        -----
            width : list of int
                :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons
                in each layer (including inputs/outputs)
            grid : int
                number of grid intervals. Default: 3.
            k : int
                order of piecewise polynomial. Default: 3.
            noise_scale : float
                initial injected noise to spline. Default: 0.1.
            base_fun : fun
                the residual function b(x). Default: torch.nn.SiLU().
            symbolic_enabled : bool
                compute or skip symbolic computations (for efficiency).
                By default: True.
            bias_trainable : bool
                bias parameters are updated or not. By default: True
            grid_eps : float
                When grid_eps = 0, the grid is uniform; when grid_eps = 1,
                the grid is partitioned using percentiles of samples.
                0 < grid_eps < 1 interpolates between the two extremes.
                Default: 0.02.
            grid_range : list/np.array of shape (2,))
                setting the range of grids. Default: [-1,1].
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            seed : int
                random seed

        Returns:
        --------
            self

        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), \
            (model.act_fun[1].in_dim, model.act_fun[1].out_dim)
        ((2, 5), (5, 1))
        """
        super(KAN, self).__init__()

        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # initializeing the numerical front ###

        self.biases = []
        self.act_fun = []
        self.depth = len(width) - 1
        self.width = width

        for l in range(self.depth):
            # splines
            scale_base = (
                1 / np.sqrt(width[l])
                + (
                    torch.randn(
                        width[l] * width[l + 1],
                    )
                    * 2
                    - 1
                )
                * noise_scale_base
            )
            sp_batch = KANLayer(
                in_dim=width[l],
                out_dim=width[l + 1],
                num=grid,
                k=k,
                noise_scale=noise_scale,
                scale_base=scale_base,
                scale_sp=1.0,
                base_fun=base_fun,
                grid_eps=grid_eps,
                grid_range=grid_range,
                sp_trainable=sp_trainable,
                sb_trainable=sb_trainable,
            )
            self.act_fun.append(sp_batch)

            # bias
            bias = nn.Linear(width[l + 1], 1, bias=False).requires_grad_(
                bias_trainable
            )
            bias.weight.data *= 0.0
            self.biases.append(bias)

        self.biases = nn.ModuleList(self.biases)
        self.act_fun = nn.ModuleList(self.act_fun)

        self.grid = grid
        self.k = k
        self.base_fun = base_fun

        # initializing the symbolic front ###
        self.symbolic_fun = []
        for l in range(self.depth):
            sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1])
            self.symbolic_fun.append(sb_batch)

        self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
        self.symbolic_enabled = symbolic_enabled

    def initialize_from_another_model(self, another_model, x):
        """
        initialize from a parent model. The parent has the same width as the
        current model but may have different grids.

        The spline coefficients are re-fitted on the current model's grids so
        that each activation reproduces the parent's spline outputs on ``x``.
        Scales, masks, biases and symbolic layers are copied from the parent;
        the copies are independent, so later in-place updates to one model do
        not leak into the other.

        Args:
        -----
            another_model : KAN
                the parent model used to initialize the current model
            x : 2D torch.float
                inputs, shape (batch, input dimension)

        Returns:
        --------
            self : KAN

        Example
        -------
        >>> model_coarse = KAN(width=[2,5,1], grid=5, k=3)
        >>> model_fine = KAN(width=[2,5,1], grid=10, k=3)
        >>> print(model_fine.act_fun[0].coef[0][0].data)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model_fine.initialize_from_another_model(model_coarse, x);
        >>> print(model_fine.act_fun[0].coef[0][0].data)
        tensor(-0.0030)
        tensor(0.0506)
        """
        # This call runs the parent's forward pass on x, which also fills the
        # parent's cached spline activations read below; no separate forward
        # pass is needed.
        self.initialize_grid_from_another_model(another_model, x)

        for l in range(self.depth):
            spb = self.act_fun[l]
            spb_parent = another_model.act_fun[l]

            preacts = another_model.spline_preacts[l]
            postsplines = another_model.spline_postsplines[l]
            # Use the cached batch size rather than x.shape[0]: forward()
            # flattens 3D inputs, so the two can differ.
            batch = preacts.shape[0]
            spb.coef.data = curve2coef(
                preacts.reshape(batch, spb.size).permute(1, 0),
                postsplines.reshape(batch, spb.size).permute(1, 0),
                spb.grid,
                k=spb.k,
            )
            # clone() so the child does not share storage with the parent.
            spb.scale_base.data = spb_parent.scale_base.data.clone()
            spb.scale_sp.data = spb_parent.scale_sp.data.clone()
            spb.mask.data = spb_parent.mask.data.clone()

        for l in range(self.depth):
            self.biases[l].weight.data = another_model.biases[
                l
            ].weight.data.clone()

        for l in range(self.depth):
            # Deep copy: a shared module would let fix_symbolic()/training on
            # the child silently modify the parent as well.
            self.symbolic_fun[l] = copy.deepcopy(another_model.symbolic_fun[l])

        return self

    def update_grid_from_samples(self, x):
        """
        update grid from samples

        Args:
        -----
            x : 2D torch.float
                inputs, shape (batch, input dimension)

        Returns:
        --------
            None

        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> print(model.act_fun[0].grid[0].data)
        >>> x = torch.rand(100,2)*5
        >>> model.update_grid_from_samples(x)
        >>> print(model.act_fun[0].grid[0].data)
        tensor([-1.0000, -0.6000, -0.2000,  0.2000,  0.6000,  1.0000])
        tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
        """
        for l in range(self.depth):
            self.forward(x)
            self.act_fun[l].update_grid_from_samples(self.acts[l])

    def initialize_grid_from_another_model(self, model, x):
        """
        initialize grid from a parent model

        Args:
        -----
            model : KAN
                parent model
            x : 2D torch.float
                inputs, shape (batch, input dimension)

        Returns:
        --------
            None

        Example
        -------
        >>> model_parent = KAN(width=[1,1], grid=5, k=3)
        >>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)\
            [None,:]
        >>> x = torch.linspace(-2,2,steps=1001)[:,None]
        >>> model = KAN(width=[1,1], grid=5, k=3)
        >>> print(model.act_fun[0].grid.data)
        >>> model = model.initialize_from_another_model(model_parent, x)
        >>> print(model.act_fun[0].grid.data)
        tensor([[-1.0000, -0.6000, -0.2000,  0.2000,  0.6000,  1.0000]])
        tensor([[-2.0000, -1.2000, -0.4000,  0.4000,  1.2000,  2.0000]])
        """
        model(x)
        for l in range(self.depth):
            self.act_fun[l].initialize_grid_from_parent(
                model.act_fun[l], model.acts[l]
            )

    def forward(self, x):
        """
        KAN forward

        Args:
        -----
            x : 2D torch.float
                inputs, shape (batch, input dimension)

        Returns:
        --------
            y : 2D torch.float
                outputs, shape (batch, output dimension)

        Example
        -------
        >>> model = KAN(width=[2,5,3], grid=5, k=3)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x).shape
        torch.Size([100, 3])
        """

        shape_size = len(x.shape)

        if shape_size == 3:
            B, C, T = x.shape
        elif shape_size == 2:
            B, T = x.shape
        else:
            raise NotImplementedError()

        x = x.view(-1, T)

        self.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
        self.spline_preacts = []
        self.spline_postsplines = []
        self.spline_postacts = []
        self.acts_scale = []
        self.acts_scale_std = []
        # self.neurons_scale = []

        self.acts.append(x)  # acts shape: (batch, width[l])

        for l in range(self.depth):

            x_numerical, preacts, postacts_numerical, postspline = (
                self.act_fun[l](x)
            )

            if self.symbolic_enabled:
                x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
            else:
                x_symbolic = 0.0
                postacts_symbolic = 0.0

            x = x_numerical + x_symbolic
            postacts = postacts_numerical + postacts_symbolic

            # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
            grid_reshape = self.act_fun[l].grid.reshape(
                self.width[l + 1], self.width[l], -1
            )
            input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4
            output_range = torch.mean(torch.abs(postacts), dim=0)
            self.acts_scale.append(output_range / input_range)
            self.acts_scale_std.append(torch.std(postacts, dim=0))
            self.spline_preacts.append(preacts.detach())
            self.spline_postacts.append(postacts.detach())
            self.spline_postsplines.append(postspline.detach())

            x = x + self.biases[l].weight
            self.acts.append(x)

        U = x.shape[1]

        if shape_size == 3:
            x = x.view(B, C, U)
        elif shape_size == 2:
            assert x.shape == (B, U)

        return x

    def set_mode(self, l, i, j, mode, mask_n=None):
        """
        set (l,i,j) activation to have mode

        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            mode : str
                'n' (numeric) or 's' (symbolic) or 'ns' (combined)
            mask_n : None or float)
                magnitude of the numeric front

        Returns:
        --------
            None
        """
        if mode == "s":
            mask_n = 0.0
            mask_s = 1.0
        elif mode == "n":
            mask_n = 1.0
            mask_s = 0.0
        elif mode == "sn" or mode == "ns":
            if mask_n is None:
                mask_n = 1.0
            else:
                mask_n = mask_n
            mask_s = 1.0
        else:
            mask_n = 0.0
            mask_s = 0.0

        self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n
        self.symbolic_fun[l].mask.data[j, i] = mask_s

    def fix_symbolic(
        self,
        l,
        i,
        j,
        fun_name,
        fit_params_bool=True,
        a_range=(-10, 10),
        b_range=(-10, 10),
        verbose=True,
        random=False,
    ):
        """
        set (l,i,j) activation to be symbolic (specified by fun_name)

        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            fun_name : str
                function name
            fit_params_bool : bool
                obtaining affine parameters through fitting (True) or setting
                default values (False)
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of b
            verbose : bool
                If True, more information is printed.
            random : bool
                initialize affine parameteres randomly or as [1,0,1,0]

        Returns:
        --------
            None or r2 (coefficient of determination)

        Example 1
        ---------
        >>> # when fit_params_bool = False
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
        >>> print(model.act_fun[0].mask.reshape(2,5))
        >>> print(model.symbolic_fun[0].mask.reshape(2,5))
        tensor([[1., 1., 1., 1., 1.],
                [1., 1., 0., 1., 1.]])
        tensor([[0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0.]])

        Example 2
        ---------
        >>> # when fit_params_bool = True
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x) # obtain activations (otherwise model does not have \
        attributes acts)
        >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
        >>> print(model.act_fun[0].mask.reshape(2,5))
        >>> print(model.symbolic_fun[0].mask.reshape(2,5))
        r2 is 0.8131332993507385
        r2 is not very high, please double check if you are choosing the \
            correct symbolic function.
        tensor([[1., 1., 1., 1., 1.],
                [1., 1., 0., 1., 1.]])
        tensor([[0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0.]])
        """
        self.set_mode(l, i, j, mode="s")
        if not fit_params_bool:
            self.symbolic_fun[l].fix_symbolic(
                i, j, fun_name, verbose=verbose, random=random
            )
            return None
        else:
            x = self.acts[l][:, i]
            y = self.spline_postacts[l][:, j, i]
            r2 = self.symbolic_fun[l].fix_symbolic(
                i,
                j,
                fun_name,
                x,
                y,
                a_range=a_range,
                b_range=b_range,
                verbose=verbose,
            )
            return r2

    def unfix_symbolic(self, l, i, j):
        """
        unfix the (l,i,j) activation function.
        """
        self.set_mode(l, i, j, mode="n")

    def unfix_symbolic_all(self):
        """
        unfix all activation functions.
        """
        for l in range(len(self.width) - 1):
            for i in range(self.width[l]):
                for j in range(self.width[l + 1]):
                    self.unfix_symbolic(l, i, j)

    def lock(self, l, ids):
        """
        lock ids in the l-th layer to be the same function

        Args:
        -----
            l : int
                layer index
            ids : 2D list
                :math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_i,j_1), \
                    (l,i_2,j_2), ...` to be the same function

        Returns:
        --------
            None

        Example
        -------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        >>> model.lock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        tensor([[0, 1],
                [2, 3],
                [4, 5]])
        tensor([[0, 1],
                [2, 1],
                [4, 5]])
        """
        self.act_fun[l].lock(ids)

    def unlock(self, l, ids):
        """
        unlock ids in the l-th layer to be the same function

        Args:
        -----
            l : int
                layer index
            ids : 2D list)
                [[i1,j1],[i2,j2],...] set (l,ii,j1), (l,i2,j2), ... to be
                unlocked

        Example:
        --------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> model.lock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        >>> model.unlock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        tensor([[0, 1],
                [2, 1],
                [4, 5]])
        tensor([[0, 1],
                [2, 3],
                [4, 5]])
        """
        self.act_fun[l].unlock(ids)

    def get_range(self, l, i, j, verbose=True):
        """
        Get the input range and output range of the (l,i,j) activation

        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index

        Returns:
        --------
            x_min : float
                minimum of input
            x_max : float
                maximum of input
            y_min : float
                minimum of output
            y_max : float
                maximum of output

        Example
        -------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x) # do a forward pass to obtain model.acts
        >>> model.get_range(0,0,0)
        x range: [-2.13 , 2.75 ]
        y range: [-0.50 , 1.83 ]
        (tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275))
        """
        x = self.spline_preacts[l][:, j, i]
        y = self.spline_postacts[l][:, j, i]
        x_min = torch.min(x)
        x_max = torch.max(x)
        y_min = torch.min(y)
        y_max = torch.max(y)
        if verbose:
            print("x range: [" + "%.2f" % x_min, ",", "%.2f" % x_max, "]")
            print("y range: [" + "%.2f" % y_min, ",", "%.2f" % y_max, "]")
        return x_min, x_max, y_min, y_max

    def plot(
        self,
        folder="./figures",
        beta=3,
        mask=False,
        mode="supervised",
        scale=0.5,
        tick=False,
        sample=False,
        in_vars=None,
        out_vars=None,
        title=None,
    ):
        """
        plot KAN

        Args:
        -----
            folder : str
                the folder to store pngs
            beta : float
                positive number. control the transparency of each activation.
                transparency = tanh(beta*l1).
            mask : bool
                If True, plot with mask (need to run prune() first to obtain
                mask). If False (by default), plot all activation functions.
            mode : bool
                "supervised" or "unsupervised". If "supervised", l1 is
                measured by absolution value (not subtracting mean);
                if "unsupervised", l1 is measured by standard deviation
                (subtracting mean).
            scale : float
                control the size of the diagram
            in_vars: None or list of str
                the name(s) of input variables
            out_vars: None or list of str
                the name(s) of output variables
            title: None or str
                title

        Returns:
        --------
            Figure

        Example
        -------
        >>> # see more interactive examples in demos
        >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x) # do a forward pass to obtain model.acts
        >>> model.plot()
        """
        if not os.path.exists(folder):
            os.makedirs(folder)
        # matplotlib.use('Agg')
        depth = len(self.width) - 1
        for l in range(depth):
            w_large = 2.0
            for i in range(self.width[l]):
                for j in range(self.width[l + 1]):
                    rank = torch.argsort(self.acts[l][:, i])
                    fig, ax = plt.subplots(figsize=(w_large, w_large))

                    num = rank.shape[0]  # noqa

                    symbol_mask = self.symbolic_fun[l].mask[j][i]
                    numerical_mask = self.act_fun[l].mask.reshape(
                        self.width[l + 1], self.width[l]
                    )[j][i]
                    if symbol_mask > 0.0 and numerical_mask > 0.0:
                        color = "purple"
                        alpha_mask = 1
                    if symbol_mask > 0.0 and numerical_mask == 0.0:
                        color = "red"
                        alpha_mask = 1
                    if symbol_mask == 0.0 and numerical_mask > 0.0:
                        color = "black"
                        alpha_mask = 1
                    if symbol_mask == 0.0 and numerical_mask == 0.0:
                        color = "white"
                        alpha_mask = 0

                    if tick:
                        ax.tick_params(
                            axis="y", direction="in", pad=-22, labelsize=50
                        )
                        ax.tick_params(
                            axis="x", direction="in", pad=-15, labelsize=50
                        )
                        x_min, x_max, y_min, y_max = self.get_range(
                            l, i, j, verbose=False
                        )
                        plt.xticks(
                            [x_min, x_max], ["%2.f" % x_min, "%2.f" % x_max]
                        )
                        plt.yticks(
                            [y_min, y_max], ["%2.f" % y_min, "%2.f" % y_max]
                        )
                    else:
                        plt.xticks([])
                        plt.yticks([])
                    if alpha_mask == 1:
                        plt.gca().patch.set_edgecolor("black")
                    else:
                        plt.gca().patch.set_edgecolor("white")
                    plt.gca().patch.set_linewidth(1.5)
                    # plt.axis('off')

                    plt.plot(
                        self.acts[l][:, i][rank].cpu().detach().numpy(),
                        self.spline_postacts[l][:, j, i][rank]
                        .cpu()
                        .detach()
                        .numpy(),
                        color=color,
                        lw=5,
                    )
                    if sample:
                        plt.scatter(
                            self.acts[l][:, i][rank].cpu().detach().numpy(),
                            self.spline_postacts[l][:, j, i][rank]
                            .cpu()
                            .detach()
                            .numpy(),
                            color=color,
                            s=400 * scale**2,
                        )
                    plt.gca().spines[:].set_color(color)

                    lock_id = (
                        self.act_fun[l]
                        .lock_id[j * self.width[l] + i]
                        .long()
                        .item()
                    )
                    if lock_id > 0:
                        im = plt.imread(f"{folder}/lock.png")
                        newax = fig.add_axes([0.15, 0.7, 0.15, 0.15])
                        plt.text(500, 400, lock_id, fontsize=15)
                        newax.imshow(im)
                        newax.axis("off")

                    plt.savefig(
                        f"{folder}/sp_{l}_{i}_{j}.png",
                        bbox_inches="tight",
                        dpi=400,
                    )
                    plt.close()

        def score2alpha(score):
            return np.tanh(beta * score)

        if mode == "supervised":
            alpha = [
                score2alpha(score.cpu().detach().numpy())
                for score in self.acts_scale
            ]
        elif mode == "unsupervised":
            alpha = [
                score2alpha(score.cpu().detach().numpy())
                for score in self.acts_scale_std
            ]

        # draw skeleton
        width = np.array(self.width)
        A = 1
        y0 = 0.4  # 0.4

        # plt.figure(figsize=(5,5*(neuron_depth-1)*y0))
        neuron_depth = len(width)
        min_spacing = A / np.maximum(np.max(width), 5)

        max_neuron = np.max(width)  # noqa
        max_num_weights = np.max(width[:-1] * width[1:])
        y1 = 0.4 / np.maximum(max_num_weights, 3)

        fig, ax = plt.subplots(
            figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0)
        )
        # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))

        # plot scatters and lines
        for l in range(neuron_depth):
            n = width[l]
            spacing = A / n  # noqa
            for i in range(n):
                plt.scatter(
                    1 / (2 * n) + i / n,
                    l * y0,
                    s=min_spacing**2 * 10000 * scale**2,
                    color="black",
                )

                if l < neuron_depth - 1:
                    # plot connections
                    n_next = width[l + 1]
                    N = n * n_next
                    for j in range(n_next):
                        id_ = i * n_next + j

                        symbol_mask = self.symbolic_fun[l].mask[j][i]
                        numerical_mask = self.act_fun[l].mask.reshape(
                            self.width[l + 1], self.width[l]
                        )[j][i]
                        if symbol_mask == 1.0 and numerical_mask == 1.0:
                            color = "purple"
                            alpha_mask = 1.0
                        if symbol_mask == 1.0 and numerical_mask == 0.0:
                            color = "red"
                            alpha_mask = 1.0
                        if symbol_mask == 0.0 and numerical_mask == 1.0:
                            color = "black"
                            alpha_mask = 1.0
                        if symbol_mask == 0.0 and numerical_mask == 0.0:
                            color = "white"
                            alpha_mask = 0.0
                        if mask:
                            plt.plot(
                                [1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N],
                                [l * y0, (l + 1 / 2) * y0 - y1],
                                color=color,
                                lw=2 * scale,
                                alpha=alpha[l][j][i]
                                * self.mask[l][i].item()
                                * self.mask[l + 1][j].item(),
                            )
                            plt.plot(
                                [
                                    1 / (2 * N) + id_ / N,
                                    1 / (2 * n_next) + j / n_next,
                                ],
                                [(l + 1 / 2) * y0 + y1, (l + 1) * y0],
                                color=color,
                                lw=2 * scale,
                                alpha=alpha[l][j][i]
                                * self.mask[l][i].item()
                                * self.mask[l + 1][j].item(),
                            )
                        else:
                            plt.plot(
                                [1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N],
                                [l * y0, (l + 1 / 2) * y0 - y1],
                                color=color,
                                lw=2 * scale,
                                alpha=alpha[l][j][i] * alpha_mask,
                            )
                            plt.plot(
                                [
                                    1 / (2 * N) + id_ / N,
                                    1 / (2 * n_next) + j / n_next,
                                ],
                                [(l + 1 / 2) * y0 + y1, (l + 1) * y0],
                                color=color,
                                lw=2 * scale,
                                alpha=alpha[l][j][i] * alpha_mask,
                            )

            plt.xlim(0, 1)
            plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0)

        # -- Transformation functions
        DC_to_FC = ax.transData.transform
        FC_to_NFC = fig.transFigure.inverted().transform
        # -- Take data coordinates and transform them to normalized figure coordinates
        DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))  # noqa

        plt.axis("off")

        # plot splines
        for l in range(neuron_depth - 1):
            n = width[l]
            for i in range(n):
                n_next = width[l + 1]
                N = n * n_next
                for j in range(n_next):
                    id_ = i * n_next + j
                    im = plt.imread(f"{folder}/sp_{l}_{i}_{j}.png")
                    left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0]
                    right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0]
                    bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1]
                    up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1]
                    newax = fig.add_axes(
                        [left, bottom, right - left, up - bottom]
                    )
                    # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
                    if not mask:
                        newax.imshow(im, alpha=alpha[l][j][i])
                    else:
                        # make sure to run model.prune() first to compute mask ###
                        newax.imshow(
                            im,
                            alpha=alpha[l][j][i]
                            * self.mask[l][i].item()
                            * self.mask[l + 1][j].item(),
                        )
                    newax.axis("off")

        if in_vars is not None:
            n = self.width[0]
            for i in range(n):
                plt.gcf().get_axes()[0].text(
                    1 / (2 * (n)) + i / (n),
                    -0.1,
                    in_vars[i],
                    fontsize=40 * scale,
                    horizontalalignment="center",
                    verticalalignment="center",
                )

        if out_vars is not None:
            n = self.width[-1]
            for i in range(n):
                plt.gcf().get_axes()[0].text(
                    1 / (2 * (n)) + i / (n),
                    y0 * (len(self.width) - 1) + 0.1,
                    out_vars[i],
                    fontsize=40 * scale,
                    horizontalalignment="center",
                    verticalalignment="center",
                )

        if title is not None:
            plt.gcf().get_axes()[0].text(
                0.5,
                y0 * (len(self.width) - 1) + 0.2,
                title,
                fontsize=40 * scale,
                horizontalalignment="center",
                verticalalignment="center",
            )

    def train_kan(
        self,
        dataset,
        opt="LBFGS",
        steps=100,
        log=1,
        lamb=0.0,
        lamb_l1=1.0,
        lamb_entropy=2.0,
        lamb_coef=0.0,
        lamb_coefdiff=0.0,
        update_grid=True,
        grid_update_num=10,
        loss_fn=None,
        lr=1.0,
        stop_grid_update_step=50,
        batch=-1,
        small_mag_threshold=1e-16,
        small_reg_factor=1.0,
        metrics=None,
        sglr_avoid=False,
        save_fig=False,
        in_vars=None,
        out_vars=None,
        beta=3,
        save_fig_freq=1,
        img_folder="./video",
        device="cpu",
    ):
        """
        train the model on ``dataset`` with LBFGS or Adam and return the
        per-step loss / regularization histories

        Args:
        -----
            dataset : dic
                contains dataset['train_input'], dataset['train_label'],
                dataset['test_input'], dataset['test_label']
            opt : str
                "LBFGS" or "Adam"
            steps : int
                training steps
            log : int
                logging frequency (the progress-bar description is refreshed
                every ``log`` steps)
            lamb : float
                overall penalty strength
            lamb_l1 : float
                l1 penalty strength
            lamb_entropy : float
                entropy penalty strength
            lamb_coef : float
                coefficient magnitude penalty strength
            lamb_coefdiff : float
                difference of nearby coefficients (smoothness) penalty
                strength
            update_grid : bool
                If True, update grid regularly before stop_grid_update_step
            grid_update_num : int
                the number of grid updates before stop_grid_update_step.
                A non-positive value disables grid updates.
            loss_fn : callable or None
                called as loss_fn(prediction, label) for both the training
                and the test loss. If None, the mean squared error is used.
            lr : float
                learning rate handed to the optimizer
            stop_grid_update_step : int
                no grid updates after this training step
            batch : int
                batch size, if -1 then full. The test batch is capped at the
                number of test samples.
            small_mag_threshold : float
                threshold to determine large or small numbers (may want to
                apply larger penalty to smaller numbers)
            small_reg_factor : float
                penalty strength applied to small factors relative to large
                factors
            metrics : list of callables or None
                each entry is called without arguments once per step and
                ``.item()`` of its result is appended to
                results[entry.__name__]
            sglr_avoid : bool
                If True, samples whose prediction row sums to NaN are left
                out of the training loss
            save_fig : bool
                If True, call self.plot() and save a .jpg into img_folder
                every save_fig_freq steps
            in_vars, out_vars, beta :
                forwarded unchanged to self.plot() when save_fig is True
            save_fig_freq : int
                save figure every (save_fig_freq) step
            img_folder : str
                folder receiving the saved figures (created if missing)
            device : str
                device

        Returns:
        --------
            results : dic
                results['train_loss'], 1D array of training losses (RMSE)
                results['test_loss'], 1D array of test losses (RMSE)
                results['reg'], 1D array of regularization
                plus one list per metric, keyed by the metric's __name__.
                NOTE: the recorded losses are sqrt(loss_fn(...)), which is
                an RMSE only for the default mean-squared-error loss.

        Raises:
        -------
            ValueError
                if ``opt`` is neither "LBFGS" nor "Adam"

        Example
        -------
        >>> # for interactive examples, please see demos
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train_kan(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model.plot()
        """
        # Fail early: previously an unknown optimizer silently skipped every
        # update and crashed later on an undefined (or stale) train_loss.
        if opt not in ("LBFGS", "Adam"):
            raise ValueError(
                f"unsupported optimizer {opt!r}; expected 'LBFGS' or 'Adam'"
            )

        def reg(acts_scale):
            """Regularization: l1 + entropy of activation scales, plus
            optional penalties on the spline coefficients."""

            def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
                # ``>=`` (not ``>``) so that x == th is penalised too instead
                # of falling between both branches and contributing zero.
                return (x < th) * x * factor + (x >= th) * (
                    x + (factor - 1) * th
                )

            total = 0.0
            for scale in acts_scale:
                vec = scale.reshape(-1)
                p = vec / torch.sum(vec)
                l1 = torch.sum(nonlinear(vec))
                entropy = -torch.sum(p * torch.log2(p + 1e-4))
                total += lamb_l1 * l1 + lamb_entropy * entropy

            # regularize coefficient to encourage spline to be zero
            for layer in self.act_fun:
                coeff_l1 = torch.sum(
                    torch.mean(torch.abs(layer.coef), dim=1)
                )
                coeff_diff_l1 = torch.sum(
                    torch.mean(torch.abs(torch.diff(layer.coef)), dim=1)
                )
                total += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1

            return total

        pbar = tqdm(range(steps), desc="description", ncols=100)

        if loss_fn is None:

            def loss_fn(x, y):
                return torch.mean((x - y) ** 2)

        loss_fn_eval = loss_fn

        # Keep the update period >= 1 so the modulo below can never divide
        # by zero (e.g. stop_grid_update_step < grid_update_num).
        if update_grid and grid_update_num > 0:
            grid_update_freq = max(
                1, int(stop_grid_update_step / grid_update_num)
            )
        else:
            grid_update_freq = None

        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        else:
            optimizer = LBFGS(
                self.parameters(),
                lr=lr,
                history_size=10,
                line_search_fn="strong_wolfe",
                tolerance_grad=1e-32,
                tolerance_change=1e-32,
                tolerance_ys=1e-32,
            )

        results = {"train_loss": [], "test_loss": [], "reg": []}
        if metrics is not None:
            for metric in metrics:
                results[metric.__name__] = []

        n_train = dataset["train_input"].shape[0]
        n_test = dataset["test_input"].shape[0]
        if batch == -1 or batch > n_train:
            batch_size = n_train
            batch_size_test = n_test
        else:
            batch_size = batch
            # Sampling without replacement cannot draw more rows than exist.
            batch_size_test = min(batch, n_test)

        # Written by closure(); kept in this function's scope (nonlocal)
        # rather than in module globals so runs cannot see stale values.
        train_loss = None
        reg_ = None

        def batch_losses():
            """Forward the current training batch; return (loss, reg)."""
            pred = self.forward(dataset["train_input"][train_id].to(device))
            labels = dataset["train_label"][train_id]
            if sglr_avoid:
                # drop samples whose prediction contains NaN
                valid = torch.where(~torch.isnan(torch.sum(pred, dim=1)))[0]
                fit = loss_fn(pred[valid], labels[valid].to(device))
            else:
                fit = loss_fn(pred, labels.to(device))
            return fit, reg(self.acts_scale)

        def closure():
            nonlocal train_loss, reg_
            optimizer.zero_grad()
            train_loss, reg_ = batch_losses()
            objective = train_loss + lamb * reg_
            objective.backward()
            return objective

        if save_fig:
            os.makedirs(img_folder, exist_ok=True)

        for step in pbar:

            train_id = np.random.choice(n_train, batch_size, replace=False)
            test_id = np.random.choice(n_test, batch_size_test, replace=False)

            if (
                grid_update_freq is not None
                and step < stop_grid_update_step
                and step % grid_update_freq == 0
            ):
                self.update_grid_from_samples(
                    dataset["train_input"][train_id].to(device)
                )

            if opt == "LBFGS":
                # LBFGS may evaluate the closure several times per step; the
                # values of the last evaluation are the ones recorded.
                optimizer.step(closure)
            else:
                closure()
                optimizer.step()

            test_loss = loss_fn_eval(
                self.forward(dataset["test_input"][test_id].to(device)),
                dataset["test_label"][test_id].to(device),
            )

            if step % log == 0:
                pbar.set_description(
                    "train loss: %.2e | test loss: %.2e | reg: %.2e "
                    % (
                        torch.sqrt(train_loss).cpu().detach().numpy(),
                        torch.sqrt(test_loss).cpu().detach().numpy(),
                        reg_.cpu().detach().numpy(),
                    )
                )

            if metrics is not None:
                for metric in metrics:
                    results[metric.__name__].append(metric().item())

            results["train_loss"].append(
                torch.sqrt(train_loss).cpu().detach().numpy()
            )
            results["test_loss"].append(
                torch.sqrt(test_loss).cpu().detach().numpy()
            )
            results["reg"].append(reg_.cpu().detach().numpy())

            if save_fig and step % save_fig_freq == 0:
                self.plot(
                    folder=img_folder,
                    in_vars=in_vars,
                    out_vars=out_vars,
                    title="Step {}".format(step),
                    beta=beta,
                )
                plt.savefig(
                    img_folder + "/" + str(step) + ".jpg",
                    bbox_inches="tight",
                    dpi=200,
                )
                plt.close()

        return results

    def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None):
        """
        pruning KAN on the node level. If a node has small incoming or
        outgoing connection, it will be pruned away.

        Besides returning the smaller model, this also stores the neuron
        mask on ``self.mask`` and zeroes the masks of the pruned nodes of
        the current model through ``remove_node``.

        Args:
        -----
            threshold : float
                the threshold used to determine whether a node is small enough
            mode : str
                "auto" or "manual". If "auto", the threshold will be used to
                automatically prune away nodes. If "manual", active_neurons_id
                is needed to specify which neurons are kept
                (others are thrown away).
            active_neurons_id : list of id lists
                For example, [[0,1],[0,2,3]] means keeping the 0/1 neuron in
                the 1st hidden layer and the 0/2/3 neuron in the 2nd hidden
                layer. Pruning input and output neurons is not supported yet.
                For backward compatibility a list with a leading entry for
                the input layer (hidden layer h stored at index h, entry 0
                ignored) is accepted as well.

        Returns:
        --------
            model2 : KAN
                pruned model

        Raises:
        -------
            ValueError
                if ``mode`` is unknown, or if ``mode`` is "manual" and
                ``active_neurons_id`` is None

        Example
        -------
        >>> # for more interactive examples, please see demos
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train_kan(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model.prune()
        >>> model.plot(mask=True)
        """
        if mode not in ("auto", "manual"):
            raise ValueError(
                f"unknown prune mode {mode!r}; expected 'auto' or 'manual'"
            )

        n_hidden = len(self.acts_scale) - 1

        id_offset = 0
        if mode == "manual":
            if active_neurons_id is None:
                raise ValueError(
                    "mode='manual' requires active_neurons_id to be given"
                )
            # One list per hidden layer is the documented layout. Anything
            # longer is the legacy layout that carries an unused leading
            # entry; the two cannot be confused because the legacy layout
            # needs at least n_hidden + 1 entries.
            id_offset = 0 if len(active_neurons_id) == n_hidden else 1

        mask = [torch.ones(self.width[0])]
        active_neurons = [list(range(self.width[0]))]
        for i in range(n_hidden):
            if mode == "auto":
                # A hidden neuron is kept only if both its strongest incoming
                # and its strongest outgoing activation exceed the threshold.
                # (assumes acts_scale[i] is indexed [out, in] -- TODO confirm)
                in_important = (
                    torch.max(self.acts_scale[i], dim=1)[0] > threshold
                )
                out_important = (
                    torch.max(self.acts_scale[i + 1], dim=0)[0] > threshold
                )
                overall_important = in_important * out_important
            else:
                overall_important = torch.zeros(
                    self.width[i + 1], dtype=torch.bool
                )
                overall_important[active_neurons_id[i + id_offset]] = True
            mask.append(overall_important.float())
            # Indices of the neurons that are KEPT (the important ones).
            active_neurons.append(torch.where(overall_important)[0])
        active_neurons.append(list(range(self.width[-1])))
        mask.append(torch.ones(self.width[-1]))

        self.mask = mask  # this is neuron mask for the whole model

        # update act_fun[l].mask
        for l in range(n_hidden):
            kept = set(active_neurons[l + 1].tolist())
            for i in range(self.width[l + 1]):
                if i not in kept:
                    self.remove_node(l + 1, i)

        model2 = KAN(
            copy.deepcopy(self.width),
            self.grid,
            self.k,
            base_fun=self.base_fun,
        )
        model2.load_state_dict(self.state_dict())
        for i in range(len(self.acts_scale)):
            if i < n_hidden:
                model2.biases[i].weight.data = model2.biases[i].weight.data[
                    :, active_neurons[i + 1]
                ]

            model2.act_fun[i] = model2.act_fun[i].get_subset(
                active_neurons[i], active_neurons[i + 1]
            )
            model2.width[i] = len(active_neurons[i])
            model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(
                active_neurons[i], active_neurons[i + 1]
            )

        return model2

    def remove_edge(self, l, i, j):
        """
        remove activtion phi(l,i,j) (set its mask to zero)

        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index

        Returns:
        --------
            None
        """
        self.act_fun[l].mask[j * self.width[l] + i] = 0.0

    def remove_node(self, l, i):
        """
        remove neuron (l,i) (set the masks of all incoming and outgoing
        activation functions to zero)

        Args:
        -----
            l : int
                layer index
            i : int
                neuron index

        Returns:
        --------
            None
        """
        self.act_fun[l - 1].mask[
            i * self.width[l - 1] + torch.arange(self.width[l - 1])
        ] = 0.0
        self.act_fun[l].mask[
            torch.arange(self.width[l + 1]) * self.width[l] + i
        ] = 0.0
        self.symbolic_fun[l - 1].mask[i, :] *= 0.0
        self.symbolic_fun[l].mask[:, i] *= 0.0

    def suggest_symbolic(
        self,
        l,
        i,
        j,
        a_range=(-10, 10),
        b_range=(-10, 10),
        lib=None,
        topk=5,
        verbose=True,
    ):
        """suggest the symbolic candidates of phi(l,i,j)

        Every candidate of the library is fitted once with fix_symbolic and
        ranked by the returned r2. The edge is passed to unfix_symbolic
        afterwards, so calling this on an edge that was already symbolic
        discards that state.

        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            a_range : tuple
                forwarded to fix_symbolic as ``a_range``
            b_range : tuple
                forwarded to fix_symbolic as ``b_range``
            lib : list of str or None
                names of the candidates, looked up in the global default
                library. If lib = None, the whole global default library
                will be used.
            topk : int
                display the top k symbolic functions (according to r2)
            verbose : bool
                If True, more information will be printed.

        Returns:
        --------
            best_name : str
                name of the candidate with the highest r2
            best_fun :
                the library entry stored under best_name
            best_r2 : float
                r2 of that candidate

        Raises:
        -------
            IndexError
                if the library is empty
            KeyError
                if ``lib`` names a function missing from the global library

        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train_kan(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.suggest_symbolic(0,0,0)
        function , r2
        sin , 0.9994412064552307
        gaussian , 0.9196369051933289
        tanh , 0.8608126044273376
        sigmoid , 0.8578218817710876
        arctan , 0.842217743396759
        """
        if lib is None:
            symbolic_lib = SYMBOLIC_LIB
        else:
            symbolic_lib = {item: SYMBOLIC_LIB[item] for item in lib}

        # Materialise once; the ranking below indexes into this list.
        candidates = list(symbolic_lib.items())

        r2s = [
            self.fix_symbolic(
                l, i, j, name, a_range=a_range, b_range=b_range, verbose=False
            ).item()
            for name, _ in candidates
        ]

        self.unfix_symbolic(l, i, j)

        if not candidates:
            raise IndexError("symbolic library is empty; nothing to suggest")

        scores = np.array(r2s)
        # argsort places NaN last, so after reversing a NaN score would be
        # reported as the best fit. Rank NaN as the worst score instead.
        ranking_key = np.where(np.isnan(scores), -np.inf, scores)
        order = np.argsort(ranking_key)[::-1]

        if verbose:
            print("function", ",", "r2")
            for rank in range(min(topk, len(candidates))):
                idx = order[rank]
                print(candidates[idx][0], ",", scores[idx])

        # The best candidate is taken from the full ranking, so it is
        # available even when topk is 0.
        best = order[0]
        best_name, best_fun = candidates[best]
        best_r2 = scores[best]
        return best_name, best_fun, best_r2

    def auto_symbolic(
        self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1
    ):
        """
        automatic symbolic regression: using top 1 suggestion from
        suggest_symbolic to replace splines with symbolic activations

        Args:
        -----
            lib : None or a list of function names
                the symbolic library
            verbose : int
                verbosity

        Returns:
        --------
            None (print suggested symbolic formulas)

        Example 1
        ---------
        >>> # default library
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic()
        fixing (0,0,0) with sin, r2=0.9994837045669556
        fixing (0,1,0) with cosh, r2=0.9978033900260925
        fixing (1,0,0) with arctan, r2=0.9997088313102722

        Example 2
        ---------
        >>> # customized library
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic(lib=['exp','sin','x^2'])
        fixing (0,0,0) with sin, r2=0.999411404132843
        fixing (0,1,0) with x^2, r2=0.9962921738624573
        fixing (1,0,0) with exp, r2=0.9980258941650391
        """
        for l in range(len(self.width) - 1):
            for i in range(self.width[l]):
                for j in range(self.width[l + 1]):
                    if self.symbolic_fun[l].mask[j, i] > 0.0:
                        print(f"skipping ({l},{i},{j}) since already symbolic")
                    else:
                        name, fun, r2 = self.suggest_symbolic(
                            l,
                            i,
                            j,
                            a_range=a_range,
                            b_range=b_range,
                            lib=lib,
                            verbose=False,
                        )
                        self.fix_symbolic(l, i, j, name, verbose=verbose > 1)
                        if verbose >= 1:
                            print(f"fixing ({l},{i},{j}) with {name}, r2={r2}")

    def symbolic_formula(
        self, floating_digit=2, var=None, normalizer=None, simplify=False
    ):
        """
        obtain the symbolic formula

        The expressions of every layer (rounded to ``floating_digit``) are
        also stored on ``self.symbolic_acts``.

        Args:
        -----
            floating_digit : int
                the number of digits to display
            var : list of str
                the name of variables (if not provided, by default using
                ['x_1', 'x_2', ...])
            normalizer : [mean array (floats), scale array (floats)]
                the normalization applied to inputs: input i enters the
                formula as (x_i - normalizer[0][i]) / normalizer[1][i]
            simplify : bool
                If True, simplify the equation at each step
                (usually quite slow), so set up False by default.

        Returns:
        --------
            (formulas, variables) : (list of sympy expressions, list)
                one rounded expression per output neuron, and the list of
                input symbols (before normalization).
                None is returned (after printing a hint) if composing an
                activation fails, e.g. because it is not symbolic yet.

        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, \
            grid_eps=0.02)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train_kan(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic(lib=['exp','sin','x^2'])
        >>> model.train_kan(dataset, opt='LBFGS', steps=50, lamb=0.00, \
            update_grid=False);
        >>> model.symbolic_formula()
        """

        def ex_round(ex1, floating_digit=floating_digit):
            """Round every float constant appearing in ``ex1``."""
            ex2 = ex1
            for a in sympy.preorder_traversal(ex1):
                if isinstance(a, sympy.Float):
                    ex2 = ex2.subs(a, round(a, floating_digit))
            return ex2

        # define variables
        # (built directly instead of through exec(): names created by exec()
        # inside a function are not visible to a later exec() on
        # Python >= 3.13, which made the old two-step construction fail)
        if var is None:
            x = [
                sympy.Symbol(f"x_{ii}")
                for ii in range(1, self.width[0] + 1)
            ]
        else:
            x = [sympy.symbols(var_) for var_ in var]

        x0 = x

        if normalizer is not None:
            mean = normalizer[0]
            std = normalizer[1]
            x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]

        symbolic_acts = [x]

        for l in range(len(self.width) - 1):
            y = []
            for j in range(self.width[l + 1]):
                yj = 0.0
                for i in range(self.width[l]):
                    a, b, c, d = self.symbolic_fun[l].affine[j, i]
                    sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
                    try:
                        yj += c * sympy_fun(a * x[i] + b) + d
                    except Exception:
                        # Best effort: report and give up instead of raising.
                        print(
                            "make sure all activations need to be converted "
                            "to symbolic formulas first!"
                        )
                        return None
                node = yj + self.biases[l].weight.data[0, j]
                y.append(sympy.simplify(node) if simplify else node)

            x = y
            symbolic_acts.append(x)

        self.symbolic_acts = [
            [ex_round(expr) for expr in layer_acts]
            for layer_acts in symbolic_acts
        ]

        # Reuse the rounded output layer (as a fresh list, so callers cannot
        # mutate the stored one) instead of rounding it a second time.
        return list(self.symbolic_acts[-1]), x0

    def clear_ckpts(self, folder="./model_ckpt"):
        """
        clear all checkpoints

        Removes every non-hidden file directly inside ``folder``;
        sub-directories are left untouched. If ``folder`` does not exist it
        is created instead.

        Args:
        -----
            folder : str
                the folder that stores checkpoints

        Returns:
        --------
            None
        """
        if not os.path.exists(folder):
            os.makedirs(folder)
            return

        # Escape the folder so characters such as "[", "*" or "?" in its
        # name are matched literally instead of being read as a pattern.
        pattern = os.path.join(glob.escape(folder), "*")
        for path in glob.glob(pattern):
            # os.remove() cannot delete directories; skip them rather than
            # aborting half-way through the clean-up.
            if os.path.isdir(path) and not os.path.islink(path):
                continue
            os.remove(path)

    def save_ckpt(self, name, folder="./model_ckpt"):
        """
        save the current model as checkpoint

        Args:
        -----
            name: str
                the name of the checkpoint to be saved
            folder : str
                the folder that stores checkpoints

        Returns:
        --------
            None
        """

        if not os.path.exists(folder):
            os.makedirs(folder)

        torch.save(self.state_dict(), folder + "/" + name)
        print("save this model to", folder + "/" + name)

    def load_ckpt(self, name, folder="./model_ckpt"):
        """
        load a checkpoint to the current model

        Args:
        -----
            name: str
                the name of the checkpoint to be loaded
            folder : str
                the folder that stores checkpoints

        Returns:
        --------
            None
        """
        self.load_state_dict(torch.load(folder + "/" + name))
